1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

HiDream Image (#11231)

* HiDream Image


---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
hlky
2025-04-11 17:31:34 +01:00
committed by GitHub
parent bc261058ee
commit 0ef29355c9
15 changed files with 1976 additions and 1 deletions

View File

@@ -175,7 +175,7 @@
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
- local: quantization/quanto
title: quanto
title: Quantization Methods
- sections:
@@ -300,6 +300,8 @@
title: EasyAnimateTransformer3DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuan_video_transformer_3d
@@ -446,6 +448,8 @@
title: Flux
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuan_video

View File

@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HiDreamImageTransformer2DModel
A Transformer model for image-like data from [HiDream-I1](https://huggingface.co/HiDream-ai).
The model can be loaded with the following code snippet.
```python
from diffusers import HiDreamImageTransformer2DModel
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HiDreamImageTransformer2DModel
[[autodoc]] HiDreamImageTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput

View File

@@ -0,0 +1,43 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# HiDreamImage
[HiDream-I1](https://huggingface.co/HiDream-ai) by HiDream.ai
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## Available models
The following models are available for the [`HiDreamImagePipeline`](text-to-image) pipeline:
| Model name | Description |
|:---|:---|
| [`HiDream-ai/HiDream-I1-Full`](https://huggingface.co/HiDream-ai/HiDream-I1-Full) | - |
| [`HiDream-ai/HiDream-I1-Dev`](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) | - |
| [`HiDream-ai/HiDream-I1-Fast`](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) | - |
## HiDreamImagePipeline
[[autodoc]] HiDreamImagePipeline
- all
- __call__
## HiDreamImagePipelineOutput
[[autodoc]] pipelines.hidream_image.pipeline_output.HiDreamImagePipelineOutput

View File

@@ -171,6 +171,7 @@ else:
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
@@ -368,6 +369,7 @@ else:
"FluxInpaintPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
@@ -761,6 +763,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
@@ -937,6 +940,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxInpaintPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,

View File

@@ -76,6 +76,7 @@ if is_torch_available():
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
@@ -151,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanVideoTransformer3DModel,
LatteTransformer3DModel,

View File

@@ -21,6 +21,7 @@ if is_torch_available():
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel

View File

@@ -0,0 +1,896 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HiDreamImageFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
class HiDreamImagePooledEmbed(nn.Module):
def __init__(self, text_emb_dim, hidden_size):
super().__init__()
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
def forward(self, pooled_embed: torch.Tensor) -> torch.Tensor:
return self.pooled_embedder(pooled_embed)
class HiDreamImageTimestepEmbed(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
class HiDreamImageOutEmbed(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1)
hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
hidden_states = self.linear(hidden_states)
return hidden_states
class HiDreamImagePatchEmbed(nn.Module):
def __init__(
self,
patch_size=2,
in_channels=4,
out_channels=1024,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
def forward(self, latent):
latent = self.proj(latent)
return latent
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
class HiDreamImageEmbedND(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(2)
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
@maybe_allow_in_graph
class HiDreamAttention(Attention):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
upcast_attention: bool = False,
upcast_softmax: bool = False,
scale_qk: bool = True,
eps: float = 1e-5,
processor=None,
out_dim: int = None,
single: bool = False,
):
super(Attention, self).__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.out_dim = out_dim if out_dim is not None else query_dim
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.single = single
self.to_q = nn.Linear(query_dim, self.inner_dim)
self.to_k = nn.Linear(self.inner_dim, self.inner_dim)
self.to_v = nn.Linear(self.inner_dim, self.inner_dim)
self.to_out = nn.Linear(self.inner_dim, self.out_dim)
self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)
if not single:
self.to_q_t = nn.Linear(query_dim, self.inner_dim)
self.to_k_t = nn.Linear(self.inner_dim, self.inner_dim)
self.to_v_t = nn.Linear(self.inner_dim, self.inner_dim)
self.to_out_t = nn.Linear(self.inner_dim, self.out_dim)
self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
self.set_processor(processor)
def forward(
self,
norm_hidden_states: torch.Tensor,
hidden_states_masks: torch.Tensor = None,
norm_encoder_hidden_states: torch.Tensor = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
return self.processor(
self,
hidden_states=norm_hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
class HiDreamAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __call__(
self,
attn: HiDreamAttention,
hidden_states: torch.Tensor,
hidden_states_masks: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
*args,
**kwargs,
) -> torch.Tensor:
dtype = hidden_states.dtype
batch_size = hidden_states.shape[0]
query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
value_i = attn.to_v(hidden_states)
inner_dim = key_i.shape[-1]
head_dim = inner_dim // attn.heads
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
if hidden_states_masks is not None:
key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
if not attn.single:
query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
value_t = attn.to_v_t(encoder_hidden_states)
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
num_image_tokens = query_i.shape[1]
num_text_tokens = query_t.shape[1]
query = torch.cat([query_i, query_t], dim=1)
key = torch.cat([key_i, key_t], dim=1)
value = torch.cat([value_i, value_t], dim=1)
else:
query = query_i
key = key_i
value = value_i
if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
query, key = apply_rope(query, key, image_rotary_emb)
else:
query_1, query_2 = query.chunk(2, dim=-1)
key_1, key_2 = key.chunk(2, dim=-1)
query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = F.scaled_dot_product_attention(
query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
hidden_states_i = attn.to_out(hidden_states_i)
hidden_states_t = attn.to_out_t(hidden_states_t)
return hidden_states_i, hidden_states_t
else:
hidden_states = attn.to_out(hidden_states)
return hidden_states
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
self.scoring_func = "softmax"
self.alpha = aux_loss_alpha
self.seq_aux = False
# topk selection algorithm
self.norm_topk_prob = False
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# print(bsz, seq_len, h)
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == "softmax":
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(
1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None
return topk_idx, topk_weight, aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
):
super().__init__()
self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
self.experts = nn.ModuleList(
[HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
)
self.gate = MoEGate(
embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts
)
self.num_activated_experts = num_activated_experts
def forward(self, x):
wtype = x.dtype
identity = x
orig_shape = x.shape
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
x = x.repeat_interleave(self.num_activated_experts, dim=0)
y = torch.empty_like(x, dtype=wtype)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape).to(dtype=wtype)
# y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.num_activated_experts
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# for fp16 and other dtype
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
return expert_cache
class TextProjection(nn.Module):
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
@maybe_allow_in_graph
class HiDreamImageSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
# 1. Attention
self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor=HiDreamAttnProcessor(),
single=True,
)
# 3. Feed-forward
self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim=dim,
hidden_dim=4 * dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
else:
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
def forward(
self,
hidden_states: torch.Tensor,
hidden_states_masks: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
wtype = hidden_states.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[
:, None
].chunk(6, dim=-1)
# 1. MM-Attention
norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
attn_output_i = self.attn1(
norm_hidden_states,
hidden_states_masks,
image_rotary_emb=image_rotary_emb,
)
hidden_states = gate_msa_i * attn_output_i + hidden_states
# 2. Feed-forward
norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
hidden_states = ff_output_i + hidden_states
return hidden_states
@maybe_allow_in_graph
class HiDreamImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 12 * dim, bias=True))
# 1. Attention
self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
self.norm1_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor=HiDreamAttnProcessor(),
single=False,
)
# 3. Feed-forward
self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim=dim,
hidden_dim=4 * dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
else:
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
self.norm3_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
self.ff_t = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
def forward(
self,
hidden_states: torch.Tensor,
hidden_states_masks: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
wtype = hidden_states.dtype
(
shift_msa_i,
scale_msa_i,
gate_msa_i,
shift_mlp_i,
scale_mlp_i,
gate_mlp_i,
shift_msa_t,
scale_msa_t,
gate_msa_t,
shift_mlp_t,
scale_mlp_t,
gate_mlp_t,
) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
# 1. MM-Attention
norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
attn_output_i, attn_output_t = self.attn1(
norm_hidden_states,
hidden_states_masks,
norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = gate_msa_i * attn_output_i + hidden_states
encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
# 2. Feed-forward
norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
hidden_states = ff_output_i + hidden_states
encoder_hidden_states = ff_output_t + encoder_hidden_states
return hidden_states, encoder_hidden_states
class HiDreamBlock(nn.Module):
def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]):
super().__init__()
self.block = block
def forward(
self,
hidden_states: torch.Tensor,
hidden_states_masks: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
) -> torch.Tensor:
return self.block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
@register_to_config
def __init__(
self,
patch_size: Optional[int] = None,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: List[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.llama_layers = llama_layers
self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
self.x_embedder = HiDreamImagePatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
out_channels=self.inner_dim,
)
self.pe_embedder = HiDreamImageEmbedND(theta=10000, axes_dim=axes_dims_rope)
self.double_stream_blocks = nn.ModuleList(
[
HiDreamBlock(
HiDreamImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
)
for _ in range(self.config.num_layers)
]
)
self.single_stream_blocks = nn.ModuleList(
[
HiDreamBlock(
HiDreamImageSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
)
)
for _ in range(self.config.num_single_layers)
]
)
self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
caption_channels = [
caption_channels[1],
] * (num_layers + num_single_layers) + [
caption_channels[0],
]
caption_projection = []
for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
is_mps = device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training:
B, S, F = x.shape
C = F // (self.config.patch_size * self.config.patch_size)
x = (
x.reshape(B, S, self.config.patch_size, self.config.patch_size, C)
.permute(0, 4, 1, 2, 3)
.reshape(B, C, S, self.config.patch_size * self.config.patch_size)
)
else:
x_arr = []
p1 = self.config.patch_size
p2 = self.config.patch_size
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
t = x[i, : pH * pW].reshape(1, pH, pW, -1)
F_token = t.shape[-1]
C = F_token // (p1 * p2)
t = t.reshape(1, pH, pW, p1, p2, C)
t = t.permute(0, 5, 1, 3, 2, 4)
t = t.reshape(1, C, pH * p1, pW * p2)
x_arr.append(t)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
pz2 = self.config.patch_size * self.config.patch_size
if isinstance(x, torch.Tensor):
B, C = x.shape[0], x.shape[1]
device = x.device
dtype = x.dtype
else:
B, C = len(x), x[0].shape[0]
device = x[0].device
dtype = x[0].dtype
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
if img_sizes is not None:
for i, img_size in enumerate(img_sizes):
x_masks[i, 0 : img_size[0] * img_size[1]] = 1
B, C, S, _ = x.shape
x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
elif isinstance(x, torch.Tensor):
B, C, Hp1, Wp2 = x.shape
pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
x = x.permute(0, 2, 4, 3, 5, 1)
x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
img_sizes = [[pH, pW]] * B
x_masks = None
else:
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
self,
hidden_states: torch.Tensor,
timesteps: torch.LongTensor = None,
encoder_hidden_states: torch.Tensor = None,
pooled_embeds: torch.Tensor = None,
img_sizes: Optional[List[Tuple[int, int]]] = None,
img_ids: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
if hidden_states.shape[-2] != hidden_states.shape[-1]:
B, C, H, W = hidden_states.shape
patch_size = self.config.patch_size
pH, pW = H // patch_size, W // patch_size
out = torch.zeros(
(B, C, self.max_seq, patch_size * patch_size),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
out[:, :, 0 : pH * pW] = hidden_states
hidden_states = out
# 0. time
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
temb = timesteps + p_embedder
hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
if hidden_states_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = (
img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
.unsqueeze(0)
.repeat(batch_size, 1, 1)
)
hidden_states = self.x_embedder(hidden_states)
T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states[-1]
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
if self.caption_projection is not None:
new_encoder_hidden_states = []
for i, enc_hidden_state in enumerate(encoder_hidden_states):
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states)
txt_ids = torch.zeros(
batch_size,
encoder_hidden_states[-1].shape[1]
+ encoder_hidden_states[-2].shape[1]
+ encoder_hidden_states[0].shape[1],
3,
device=img_ids.device,
dtype=img_ids.dtype,
)
ids = torch.cat((img_ids, txt_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids)
# 2. Blocks
block_id = 0
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_encoder_hidden_states = torch.cat(
[initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
hidden_states_masks,
cur_encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states, initial_encoder_hidden_states = block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=cur_encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
image_tokens_seq_len = hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
hidden_states_seq_len = hidden_states.shape[1]
if hidden_states_masks is not None:
encoder_attention_mask_ones = torch.ones(
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
device=hidden_states_masks.device,
dtype=hidden_states_masks.dtype,
)
hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
hidden_states_masks,
None,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=None,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
output = self.final_layer(hidden_states, temb)
output = self.unpatchify(output, img_sizes, self.training)
if hidden_states_masks is not None:
hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output, hidden_states_masks)
return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)

View File

@@ -221,6 +221,7 @@ else:
"EasyAnimateInpaintPipeline",
"EasyAnimateControlPipeline",
]
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = [
"HunyuanVideoPipeline",
@@ -585,6 +586,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .hidream_image import HiDreamImagePipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoImageToVideoPipeline,

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["HiDreamImagePipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hidream_image"] = ["HiDreamImagePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_hidream_image import HiDreamImagePipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,739 @@
import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
LlamaForCausalLM,
PreTrainedTokenizerFast,
T5EncoderModel,
T5Tokenizer,
)
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HiDreamImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel
>>> scheduler = UniPCMultistepScheduler(
... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
... )
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
... output_hidden_states=True,
... output_attentions=True,
... torch_dtype=torch.bfloat16,
... )
>>> transformer = HiDreamImageTransformer2DModel.from_pretrained(
... "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> pipe = HiDreamImagePipeline.from_pretrained(
... "HiDream-ai/HiDream-I1-Full",
... scheduler=scheduler,
... tokenizer_4=tokenizer_4,
... text_encoder_4=text_encoder_4,
... transformer=transformer,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.enable_model_cpu_offload()
>>> image = pipe(
... 'A cat holding a sign that says "Hi-Dreams.ai".',
... height=1024,
... width=1024,
... guidance_scale=5.0,
... num_inference_steps=50,
... generator=torch.Generator("cuda").manual_seed(0),
... ).images[0]
>>> image.save("output.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class HiDreamImagePipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5Tokenizer,
text_encoder_4: LlamaForCausalLM,
tokenizer_4: PreTrainedTokenizerFast,
transformer: HiDreamImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
text_encoder_3=text_encoder_3,
text_encoder_4=text_encoder_4,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
tokenizer_3=tokenizer_3,
tokenizer_4=tokenizer_4,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
# HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 128
if getattr(self, "tokenizer_4", None) is not None:
self.tokenizer_4.pad_token = self.tokenizer_4.eos_token
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder_3.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer_3(
prompt,
padding="max_length",
max_length=min(max_sequence_length, self.tokenizer_3.model_max_length),
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_3.batch_decode(
untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds
def _get_clip_prompt_embeds(
self,
tokenizer,
text_encoder,
prompt: Union[str, List[str]],
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=min(max_sequence_length, 218),
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {218} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds
def _get_llama3_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder_4.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer_4(
prompt,
padding="max_length",
max_length=min(max_sequence_length, self.tokenizer_4.model_max_length),
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_4.batch_decode(
untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
)
outputs = self.text_encoder_4(
text_input_ids.to(device),
attention_mask=attention_mask.to(device),
output_hidden_states=True,
output_attentions=True,
)
prompt_embeds = outputs.hidden_states[1:]
prompt_embeds = torch.stack(prompt_embeds, dim=0)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_3: Union[str, List[str]],
prompt_4: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
negative_prompt_4: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
prompt_4=prompt_4,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
negative_prompt_3 = negative_prompt_3 or negative_prompt
negative_prompt_4 = negative_prompt_4 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
negative_prompt_3 = (
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
)
negative_prompt_4 = (
batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_3=negative_prompt_3,
prompt_4=negative_prompt_4,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
max_sequence_length=max_sequence_length,
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def _encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_3: Union[str, List[str]],
prompt_4: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
):
device = device or self._execution_device
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
if pooled_prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
)
pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
if prompt_embeds is None:
prompt_3 = prompt_3 or prompt
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
prompt_4 = prompt_4 or prompt
prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
_, seq_len, _ = t5_prompt_embeds.shape
t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1)
t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
_, _, seq_len, dim = llama3_prompt_embeds.shape
llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
return prompt_embeds, pooled_prompt_embeds
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
prompt_4: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
negative_prompt_4: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 128,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
division = self.vae_scale_factor * 2
S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
scale = S_max / (width * height)
scale = math.sqrt(scale)
width, height = int(width * scale // division * division), int(height * scale // division * division)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
elif prompt_embeds is not None:
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
else:
batch_size = 1
device = self._execution_device
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
prompt_4=prompt_4,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
negative_prompt_4=negative_prompt_4,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
prompt_embeds_arr = []
for n, p in zip(negative_prompt_embeds, prompt_embeds):
if len(n.shape) == 3:
prompt_embeds_arr.append(torch.cat([n, p], dim=0))
else:
prompt_embeds_arr.append(torch.cat([n, p], dim=1))
prompt_embeds = prompt_embeds_arr
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
pooled_prompt_embeds.dtype,
device,
generator,
latents,
)
if latents.shape[-2] != latents.shape[-1]:
B, C, H, W = latents.shape
pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
img_ids = img_ids.reshape(pH * pW, -1)
img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
img_ids_pad[: pH * pW, :] = img_ids
img_sizes = img_sizes.unsqueeze(0).to(latents.device)
img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
if self.do_classifier_free_guidance:
img_sizes = img_sizes.repeat(2 * B, 1)
img_ids = img_ids.repeat(2 * B, 1, 1)
else:
img_sizes = img_ids = None
# 5. Prepare timesteps
mu = calculate_shift(self.transformer.max_seq)
scheduler_kwargs = {"mu": mu}
if isinstance(self.scheduler, UniPCMultistepScheduler):
self.scheduler.set_timesteps(num_inference_steps, device=device) # , shift=math.exp(mu))
timesteps = self.scheduler.timesteps
else:
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timesteps=timestep,
encoder_hidden_states=prompt_embeds,
pooled_embeds=pooled_prompt_embeds,
img_sizes=img_sizes,
img_ids=img_ids,
return_dict=False,
)[0]
noise_pred = -noise_pred
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return HiDreamImagePipelineOutput(images=image)

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class HiDreamImagePipelineOutput(BaseOutput):
"""
Output class for HiDreamImage pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -505,6 +505,21 @@ class FluxTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class HiDreamImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -617,6 +617,21 @@ class FluxPriorReduxPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HiDreamImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

View File

@@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import (
AutoTokenizer,
CLIPTextConfig,
CLIPTextModelWithProjection,
CLIPTokenizer,
LlamaForCausalLM,
T5EncoderModel,
)
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
HiDreamImagePipeline,
HiDreamImageTransformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = HiDreamImagePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
test_layerwise_casting = True
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = HiDreamImageTransformer2DModel(
patch_size=2,
in_channels=4,
out_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=4,
caption_channels=[32, 16],
text_emb_dim=64,
num_routed_experts=4,
num_activated_experts=2,
axes_dims_rope=(4, 2, 2),
max_resolution=(32, 32),
llama_layers=(0, 1),
).eval()
torch.manual_seed(0)
vae = AutoencoderKL(scaling_factor=0.3611, shift_factor=0.1159)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
max_position_embeddings=128,
)
torch.manual_seed(0)
text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
text_encoder_4 = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
text_encoder_4.generation_config.pad_token_id = 1
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer_4 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
scheduler = FlowMatchEulerDiscreteScheduler()
components = {
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"text_encoder_3": text_encoder_3,
"tokenizer_3": tokenizer_3,
"text_encoder_4": text_encoder_4,
"tokenizer_4": tokenizer_4,
"transformer": transformer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
self.assertEqual(generated_image.shape, (128, 128, 3))
expected_image = torch.randn(128, 128, 3).numpy()
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-4)