Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Add support for lumina2 (#10642)

* Add support for lumina2

---------

Co-authored-by: csuhan <hanjiaming@whu.edu.cn>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: hlky <hlky@hlky.ac>
@@ -290,6 +290,8 @@
       title: LatteTransformer3DModel
     - local: api/models/lumina_nextdit2d
       title: LuminaNextDiT2DModel
+    - local: api/models/lumina2_transformer2d
+      title: Lumina2Transformer2DModel
     - local: api/models/ltx_video_transformer3d
       title: LTXVideoTransformer3DModel
     - local: api/models/mochi_transformer3d
@@ -442,6 +444,8 @@
       title: LEDITS++
     - local: api/pipelines/ltx_video
       title: LTXVideo
+    - local: api/pipelines/lumina2
+      title: Lumina 2.0
     - local: api/pipelines/lumina
       title: Lumina-T2X
     - local: api/pipelines/marigold
docs/source/en/api/models/lumina2_transformer2d.md (new file, 30 lines)
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Lumina2Transformer2DModel

A Diffusion Transformer model for image generation, introduced in [Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import Lumina2Transformer2DModel

transformer = Lumina2Transformer2DModel.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer", torch_dtype=torch.bfloat16)
```
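The transformer loaded this way can also be handed directly to the text-to-image pipeline added in this PR. A minimal sketch, assuming the usual diffusers component-override pattern:

```python
import torch

from diffusers import Lumina2Text2ImgPipeline, Lumina2Transformer2DModel

# Load the transformer on its own, e.g. to swap in a fine-tuned checkpoint ...
transformer = Lumina2Transformer2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer", torch_dtype=torch.bfloat16
)
# ... then pass it to the pipeline; the remaining components load from the same repo.
pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
```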
## Lumina2Transformer2DModel

[[autodoc]] Lumina2Transformer2DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
docs/source/en/api/pipelines/lumina2.md (new file, 33 lines)
@@ -0,0 +1,33 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# Lumina2

[Lumina Image 2.0: A Unified and Efficient Image Generative Model](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) is a 2 billion parameter flow-based diffusion transformer capable of generating diverse images from text descriptions.

The abstract from the paper is:

*We introduce Lumina-Image 2.0, an advanced text-to-image model that surpasses previous state-of-the-art methods across multiple benchmarks, while also shedding light on its potential to evolve into a generalist vision intelligence model. Lumina-Image 2.0 exhibits three key properties: (1) Unification – it adopts a unified architecture that treats text and image tokens as a joint sequence, enabling natural cross-modal interactions and facilitating task expansion. Besides, since high-quality captioners can provide semantically better-aligned text-image training pairs, we introduce a unified captioning system, UniCaptioner, which generates comprehensive and precise captions for the model. This not only accelerates model convergence but also enhances prompt adherence, variable-length prompt handling, and task generalization via prompt templates. (2) Efficiency – to improve the efficiency of the unified architecture, we develop a set of optimization techniques that improve semantic learning and fine-grained texture generation during training while incorporating inference-time acceleration strategies without compromising image quality. (3) Transparency – we open-source all training details, code, and models to ensure full reproducibility, aiming to bridge the gap between well-resourced closed-source research teams and independent developers.*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
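A minimal usage sketch, mirroring the example docstring shipped in this PR's pipeline:

```python
import torch

from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
# Enable memory optimizations.
pipe.enable_model_cpu_offload()

prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps."
image = pipe(prompt).images[0]
image.save("lumina2.png")
```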
## Lumina2Text2ImgPipeline

[[autodoc]] Lumina2Text2ImgPipeline
  - all
  - __call__
@@ -118,6 +118,7 @@ else:
         "Kandinsky3UNet",
         "LatteTransformer3DModel",
         "LTXVideoTransformer3DModel",
+        "Lumina2Transformer2DModel",
         "LuminaNextDiT2DModel",
         "MochiTransformer3DModel",
         "ModelMixin",
@@ -338,6 +339,7 @@ else:
         "LEditsPPPipelineStableDiffusionXL",
         "LTXImageToVideoPipeline",
         "LTXPipeline",
+        "Lumina2Text2ImgPipeline",
         "LuminaText2ImgPipeline",
         "MarigoldDepthPipeline",
         "MarigoldNormalsPipeline",
@@ -634,6 +636,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         Kandinsky3UNet,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
+        Lumina2Transformer2DModel,
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         ModelMixin,
@@ -833,6 +836,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LEditsPPPipelineStableDiffusionXL,
         LTXImageToVideoPipeline,
         LTXPipeline,
+        Lumina2Text2ImgPipeline,
         LuminaText2ImgPipeline,
         MarigoldDepthPipeline,
         MarigoldNormalsPipeline,
@@ -72,6 +72,7 @@ if is_torch_available():
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
     _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
@@ -141,6 +142,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanVideoTransformer3DModel,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
+        Lumina2Transformer2DModel,
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         OmniGenTransformer2DModel,
@@ -612,7 +612,6 @@ class LuminaFeedForward(nn.Module):
         ffn_dim_multiplier: Optional[float] = None,
     ):
         super().__init__()
-        inner_dim = int(2 * inner_dim / 3)
         # custom hidden_size factor multiplier
         if ffn_dim_multiplier is not None:
             inner_dim = int(ffn_dim_multiplier * inner_dim)
@@ -219,14 +219,13 @@ class LuminaRMSNormZero(nn.Module):
             4 * embedding_dim,
             bias=True,
         )
-        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm = RMSNorm(embedding_dim, eps=norm_eps)

     def forward(
         self,
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None])
@@ -515,6 +514,16 @@ class RMSNorm(nn.Module):
             hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
+        elif is_torch_version(">=", "2.4"):
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = nn.functional.rms_norm(
+                hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
+            )
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
         else:
             input_dtype = hidden_states.dtype
             variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
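For reference, a small self-contained sketch (assuming PyTorch >= 2.4) showing that the fused `nn.functional.rms_norm` path added above matches the manual float32 fallback it bypasses:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 8, 64)
weight = torch.ones(64)
eps = 1e-5

# Fused path taken when torch >= 2.4
fused = nn.functional.rms_norm(x, normalized_shape=(64,), weight=weight, eps=eps)

# Manual fallback used on older torch versions
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
manual = (x.to(torch.float32) * torch.rsqrt(variance + eps)).to(x.dtype) * weight

assert torch.allclose(fused, manual, atol=1e-5)
```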
@@ -21,6 +21,7 @@ if is_torch_available():
     from .transformer_flux import FluxTransformer2DModel
     from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
     from .transformer_ltx import LTXVideoTransformer3DModel
+    from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
     from .transformer_omnigen import OmniGenTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
@@ -98,7 +98,7 @@ class LuminaNextDiTBlock(nn.Module):

         self.feed_forward = LuminaFeedForward(
             dim=dim,
-            inner_dim=4 * dim,
+            inner_dim=int(4 * 2 * dim / 3),
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
         )
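The two hunks above move the 2/3 width scaling out of `LuminaFeedForward` and into the caller: the new Lumina2 block can now request an unscaled `4 * dim`, while `LuminaNextDiTBlock` passes the pre-scaled value to keep its effective width identical. A quick arithmetic check with an illustrative `dim`:

```python
dim = 2304
old_inner = int(2 * (4 * dim) / 3)  # scaling previously applied inside LuminaFeedForward
new_inner = int(4 * 2 * dim / 3)    # scaling now passed in by LuminaNextDiTBlock
assert old_inner == new_inner == 6144
```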
src/diffusers/models/transformers/transformer_lumina2.py (new file, 551 lines)
@@ -0,0 +1,551 @@
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    def __init__(
        self,
        hidden_size: int = 4096,
        cap_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        self.caption_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
        )

    def forward(
        self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(encoder_hidden_states)
        return time_embed, caption_embed


class Lumina2AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key Norm if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Apply proportional attention if enabled
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # perform Grouped-Query Attention (GQA)
        n_rep = attn.heads // kv_heads
        if n_rep >= 1:
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
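        # Shape sketch with illustrative sizes: with attn.heads=24 and kv_heads=8, n_rep=3,
        # and key/value go (B, S, 8, D) -> unsqueeze(3) -> (B, S, 8, 3, D) -> flatten(2, 3)
        # -> (B, S, 24, D), so each key/value head is shared by n_rep query heads.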
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        if attention_mask is not None:
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.type_as(query)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class Lumina2TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=Lumina2AttnProcessor2_0(),
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True,
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.modulation:
            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
            attn_output = self.attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
        else:
            norm_hidden_states = self.norm1(hidden_states)
            attn_output = self.attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            hidden_states = hidden_states + self.ffn_norm2(mlp_output)

        return hidden_states


class Lumina2RotaryPosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

        self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)

    def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
        freqs_cis = []
        # Use float32 for MPS compatibility
        dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
        result = []
        for i in range(len(self.axes_dim)):
            freqs = self.freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        batch_size = len(hidden_states)
        p_h = p_w = self.patch_size
        device = hidden_states[0].device

        l_effective_cap_len = attention_mask.sum(dim=1).tolist()
        # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]

        max_seq_len = max(cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
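        # The three position axes are (caption index, image row, image column): caption tokens
        # count along axis 0 and keep zeros on the spatial axes, while image tokens freeze
        # axis 0 at cap_len and enumerate rows/columns on axes 1 and 2 below.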
        for i in range(batch_size):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p_h, W // p_w
            assert H_tokens * W_tokens == img_len

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
            row_ids = (
                torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            )
            col_ids = (
                torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            )
            position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
            position_ids[i, cap_len : cap_len + img_len, 2] = col_ids

        freqs_cis = self._get_freqs_cis(position_ids)

        cap_freqs_cis_shape = list(freqs_cis.shape)
        cap_freqs_cis_shape[1] = attention_mask.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i in range(batch_size):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]

        flat_hidden_states = []
        for i in range(batch_size):
            img = hidden_states[i]
            C, H, W = img.size()
            img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_hidden_states.append(img)
        hidden_states = flat_hidden_states
        padded_img_embed = torch.zeros(
            batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
        )
        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
            padded_img_mask[i, : l_effective_img_len[i]] = True

        return (
            padded_img_embed,
            padded_img_mask,
            img_sizes,
            l_effective_cap_len,
            l_effective_img_len,
            freqs_cis,
            cap_freqs_cis,
            img_freqs_cis,
            max_seq_len,
        )


class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    r"""
    Lumina2NextDiT: Diffusion model with a Transformer backbone.

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 16):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 2304):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        num_layers (`int`, *optional*, defaults to 26):
            The number of layers in the model. This defines the depth of the neural network.
        num_attention_heads (`int`, *optional*, defaults to 24):
            The number of attention heads in each attention layer. This parameter specifies how many separate
            attention mechanisms are used.
        num_kv_heads (`int`, *optional*, defaults to 8):
            The number of key-value heads in the attention mechanism, if different from the number of attention
            heads. If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
            configurations.
        ffn_dim_multiplier (`float`, *optional*):
            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based
            on the model configuration.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in normalization layers.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting
            the overall scale of the model's operations.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["Lumina2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        scaling_factor: float = 1.0,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        cap_feat_dim: int = 1024,
    ) -> None:
        super().__init__()
        self.out_channels = out_channels or in_channels

        # 1. Positional, patch & conditional embeddings
        self.rope_embedder = Lumina2RotaryPosEmbed(
            theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
        )

        self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
        )

        # 2. Noise and context refinement blocks
        self.noise_refiner = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        self.context_refiner = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        use_mask_in_transformer: bool = True,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        batch_size = hidden_states.size(0)

        # 1. Condition, positional & patch embedding
        temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)

        (
            hidden_states,
            hidden_mask,
            hidden_sizes,
            encoder_hidden_len,
            hidden_len,
            joint_rotary_emb,
            encoder_rotary_emb,
            hidden_rotary_emb,
            max_seq_len,
        ) = self.rope_embedder(hidden_states, attention_mask)

        hidden_states = self.x_embedder(hidden_states)

        # 2. Context & noise refinement
        for layer in self.context_refiner:
            # NOTE: the mask can be skipped for performance via `use_mask_in_transformer=False`
            encoder_hidden_states = layer(
                encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
            )

        for layer in self.noise_refiner:
            # NOTE: the mask can be skipped for performance via `use_mask_in_transformer=False`
            hidden_states = layer(
                hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb
            )

        # 3. Attention mask preparation
        mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
        for i in range(batch_size):
            cap_len = encoder_hidden_len[i]
            img_len = hidden_len[i]
            mask[i, : cap_len + img_len] = True
            padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
            padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
        hidden_states = padded_hidden_states

        # 4. Transformer blocks
        for layer in self.layers:
            # NOTE: the mask can be skipped for performance via `use_mask_in_transformer=False`
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb
                )
            else:
                hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb)

        # 5. Output norm & projection & unpatchify
        hidden_states = self.norm_out(hidden_states, temb)

        height_tokens = width_tokens = self.config.patch_size
        output = []
        for i in range(len(hidden_sizes)):
            height, width = hidden_sizes[i]
            begin = encoder_hidden_len[i]
            end = begin + (height // height_tokens) * (width // width_tokens)
            output.append(
                hidden_states[i][begin:end]
                .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )
        output = torch.stack(output, dim=0)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
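As a quick sanity check of the shapes above, a hypothetical tiny-config smoke test (sizes chosen for illustration only; they are not the released checkpoint's configuration, and `axes_dim_rope` must sum to `hidden_size // num_attention_heads`):

```python
import torch

from diffusers import Lumina2Transformer2DModel

model = Lumina2Transformer2DModel(
    sample_size=32, patch_size=2, in_channels=4, hidden_size=64,
    num_layers=2, num_refiner_layers=1, num_attention_heads=4,
    num_kv_heads=2, axes_dim_rope=(8, 4, 4), cap_feat_dim=32,
)
latents = torch.randn(2, 4, 32, 32)                 # (batch, channels, height, width)
timestep = torch.rand(2)                            # flow-matching time
caption_feats = torch.randn(2, 16, 32)              # (batch, seq_len, cap_feat_dim)
caption_mask = torch.ones(2, 16, dtype=torch.bool)  # all caption tokens valid
sample = model(latents, timestep, caption_feats, caption_mask).sample
assert sample.shape == latents.shape                # unpatchified back to input shape
```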
@@ -256,6 +256,7 @@ else:
     _import_structure["latte"] = ["LattePipeline"]
     _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
     _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
+    _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"]
     _import_structure["marigold"].extend(
         [
             "MarigoldDepthPipeline",
@@ -597,6 +598,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     )
     from .ltx import LTXImageToVideoPipeline, LTXPipeline
     from .lumina import LuminaText2ImgPipeline
+    from .lumina2 import Lumina2Text2ImgPipeline
     from .marigold import (
         MarigoldDepthPipeline,
         MarigoldNormalsPipeline,
@@ -65,6 +65,7 @@ from .kandinsky2_2 import (
 from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
 from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
 from .lumina import LuminaText2ImgPipeline
+from .lumina2 import Lumina2Text2ImgPipeline
 from .pag import (
     HunyuanDiTPAGPipeline,
     PixArtSigmaPAGPipeline,
@@ -135,6 +136,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("flux-control", FluxControlPipeline),
         ("flux-controlnet", FluxControlNetPipeline),
         ("lumina", LuminaText2ImgPipeline),
+        ("lumina2", Lumina2Text2ImgPipeline),
         ("cogview3", CogView3PlusPipeline),
     ]
 )
src/diffusers/pipelines/lumina2/__init__.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_lumina2"] = ["Lumina2Text2ImgPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_lumina2 import Lumina2Text2ImgPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
src/diffusers/pipelines/lumina2/pipeline_lumina2.py (new file, 770 lines)
@@ -0,0 +1,770 @@
# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
    is_bs4_available,
    is_ftfy_available,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


if is_bs4_available():
    pass

if is_ftfy_available():
    pass

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import Lumina2Text2ImgPipeline

        >>> pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
        >>> # Enable memory optimizations.
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures"
        >>> image = pipe(prompt).images[0]
        ```
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
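# Worked example with the defaults above: m = (1.16 - 0.5) / (4096 - 256) = 0.66 / 3840 and
# b = 0.5 - m * 256 = 0.456, so image_seq_len = 1024 gives mu = 1024 * m + b = 0.632; the
# shift grows linearly with the number of image tokens.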

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class Lumina2Text2ImgPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Lumina-Image 2.0.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`AutoModel`]):
            Frozen text-encoder. Lumina-Image 2.0 uses
            [Gemma 2](https://huggingface.co/docs/transformers/model_doc/gemma2) as its text encoder.
        tokenizer (`AutoTokenizer`):
            Tokenizer of class
            [AutoTokenizer](https://huggingface.co/docs/transformers/main_classes/tokenizer).
        transformer ([`Lumina2Transformer2DModel`]):
            A text conditioned `Lumina2Transformer2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    def __init__(
        self,
        transformer: Lumina2Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: AutoModel,
        tokenizer: AutoTokenizer,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 8
        self.default_sample_size = (
            self.transformer.config.sample_size
            if hasattr(self, "transformer") and self.transformer is not None
            else 128
        )
        self.default_image_size = self.default_sample_size * self.vae_scale_factor
        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts."

        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "right"

    def _get_gemma_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        max_sequence_length: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because Gemma can only handle sequences up to"
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = self.text_encoder(
            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        )
        prompt_embeds = prompt_embeds.hidden_states[-2]

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        return prompt_embeds, prompt_attention_mask

    # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        negative_prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        system_prompt: Optional[str] = None,
        max_sequence_length: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`). For Lumina-T2I, this should be "".
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. For Lumina-T2I, these should be the embeddings of the ""
                string.
            max_sequence_length (`int`, defaults to `256`):
                Maximum sequence length to use for the prompt.
        """
        if device is None:
            device = self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if system_prompt is None:
            system_prompt = self.system_prompt
        if prompt is not None:
            prompt = [system_prompt + " <Prompt Start> " + p for p in prompt]

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=prompt,
                device=device,
                max_sequence_length=max_sequence_length,
            )

        batch_size, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
        prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)

        # Get negative embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt if negative_prompt is not None else ""

            # Normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=negative_prompt,
                device=device,
                max_sequence_length=max_sequence_length,
            )

            batch_size, seq_len, _ = negative_prompt_embeds.shape
            # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(
                batch_size * num_images_per_prompt, -1
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
                raise ValueError(
                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
                    f" {negative_prompt_attention_mask.shape}."
                )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
return latents
|
||||
|
||||
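
    # Worked example (assuming `vae_scale_factor == 8`): a 1024x1024 request gives
    # height = 2 * (1024 // 16) = 128 and width = 128, so with 4 latent channels and
    # an effective batch size of 1 the sampled noise has shape (1, 4, 128, 128);
    # the `// (vae_scale_factor * 2)` rounding keeps both latent sides even.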

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf. `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps
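
    # For reference, classifier-free guidance combines the two branch predictions as
    #
    #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    #
    # so `guidance_scale == 1` collapses to the conditional prediction alone, which is
    # why `do_classifier_free_guidance` requires `guidance_scale > 1`.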

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        num_inference_steps: int = 30,
        guidance_scale: float = 4.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        sigmas: Optional[List[float]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        system_prompt: Optional[str] = None,
        cfg_trunc_ratio: float = 1.0,
        cfg_normalization: bool = True,
        use_mask_in_transformer: bool = True,
        max_sequence_length: int = 256,
    ) -> Union[ImagePipelineOutput, Tuple]:
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_inference_steps (`int`, *optional*, defaults to 30):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated image.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
system_prompt (`str`, *optional*):
|
||||
The system prompt to use for the image generation.
|
||||
cfg_trunc_ratio (`float`, *optional*, defaults to `1.0`):
|
||||
The ratio of the timestep interval to apply normalization-based guidance scale.
|
||||
cfg_normalization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply normalization-based guidance scale.
|
||||
use_mask_in_transformer (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain.
|
||||
max_sequence_length (`int`, defaults to `256`):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images
|
||||
"""
|
||||
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            system_prompt=system_prompt,
        )

        # 4. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
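
        # `calculate_shift` linearly interpolates the flow-matching shift `mu` between
        # `base_shift` and `max_shift` as the image token count grows from
        # `base_image_seq_len` to `max_image_seq_len`, so larger images spend relatively
        # more of the schedule at high noise levels.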

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # compute whether to apply classifier-free truncation at this timestep
                do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
                # reverse the timestep, since Lumina uses t=0 as the noise and t=1 as the image
                current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                current_timestep = current_timestep.expand(latents.shape[0])
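                # Worked example (assuming `num_train_timesteps == 1000`): the scheduler
                # yields t from ~1000 down to ~0, so `1 - t / 1000` sweeps 0 -> 1,
                # matching Lumina's convention of t=0 for noise and t=1 for the image.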

                noise_pred_cond = self.transformer(
                    hidden_states=latents,
                    timestep=current_timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_mask=prompt_attention_mask,
                    use_mask_in_transformer=use_mask_in_transformer,
                    return_dict=False,
                )[0]

                # perform normalization-based guidance scale on a truncated timestep interval
                if self.do_classifier_free_guidance and not do_classifier_free_truncation:
                    noise_pred_uncond = self.transformer(
                        hidden_states=latents,
                        timestep=current_timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        attention_mask=negative_prompt_attention_mask,
                        use_mask_in_transformer=use_mask_in_transformer,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                    # apply normalization after classifier-free guidance
                    if cfg_normalization:
                        cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
                        noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                        noise_pred = noise_pred * (cond_norm / noise_norm)
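                        # Rescaling the guided prediction back to the conditional
                        # prediction's norm tempers the over-saturation that large
                        # guidance scales can otherwise cause.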
                else:
                    noise_pred = noise_pred_cond

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                noise_pred = -noise_pred
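                # The negation adapts Lumina's velocity convention (t=0 noise -> t=1
                # image, per the timestep reversal above) to the scheduler's opposite
                # flow direction before `scheduler.step`.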
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # update the progress bar, if appropriate
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type != "latent":
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        else:
            image = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
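
For orientation, here is a minimal end-to-end usage sketch of the pipeline class added above. The repository id comes from the Lumina Image 2.0 model card; loading the full pipeline from it, the CUDA device, and bfloat16 weights are assumptions for illustration, not requirements:

```python
import torch

from diffusers import Lumina2Text2ImgPipeline

# Assumption: the Hub repo hosts the complete pipeline (VAE, text encoder, transformer).
pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
)
pipe.to("cuda")  # assumption: a CUDA GPU is available

image = pipe(
    prompt="a photo of an astronaut riding a horse on the moon",
    num_inference_steps=30,  # matches the __call__ default
    guidance_scale=4.0,      # matches the __call__ default
).images[0]
image.save("lumina2.png")
```
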
@@ -531,6 +531,21 @@ class LTXVideoTransformer3DModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class Lumina2Transformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
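
# These dummies stand in for the real classes when PyTorch is not installed;
# `requires_backends` then raises an ImportError naming the missing "torch"
# dependency at construction time instead of failing later and less clearly.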

class LuminaNextDiT2DModel(metaclass=DummyObject):
    _backends = ["torch"]


@@ -1142,6 +1142,21 @@ class LTXPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class Lumina2Text2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class LuminaText2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
89 tests/models/transformers/test_models_transformer_lumina2.py Normal file
@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import Lumina2Transformer2DModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = Lumina2Transformer2DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 2  # N
        num_channels = 4  # C
        height = width = 16  # H, W
        embedding_dim = 32  # D
        sequence_length = 16  # L

        hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        timestep = torch.rand(size=(batch_size,)).to(torch_device)
        attention_mask = torch.ones(size=(batch_size, sequence_length), dtype=torch.bool).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
            "attention_mask": attention_mask,
        }

    @property
    def input_shape(self):
        return (4, 16, 16)

    @property
    def output_shape(self):
        return (4, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "sample_size": 16,
            "patch_size": 2,
            "in_channels": 4,
            "hidden_size": 24,
            "num_layers": 2,
            "num_refiner_layers": 1,
            "num_attention_heads": 3,
            "num_kv_heads": 1,
            "multiple_of": 2,
            "ffn_dim_multiplier": None,
            "norm_eps": 1e-5,
            "scaling_factor": 1.0,
            "axes_dim_rope": (4, 2, 2),
            "axes_lens": (128, 128, 128),
            "cap_feat_dim": 32,
        }

        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"Lumina2Transformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
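
# The tiny configuration above (hidden_size=24, two layers, 16x16 inputs with
# patch_size=2, i.e. 8x8 = 64 image tokens) keeps the shared `ModelTesterMixin`
# checks fast enough for CPU runs, e.g.:
#   pytest tests/models/transformers/test_models_transformer_lumina2.py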
0 tests/pipelines/lumina2/__init__.py Normal file
147 tests/pipelines/lumina2/test_pipeline_lumina2.py Normal file
@@ -0,0 +1,147 @@
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    Lumina2Text2ImgPipeline,
    Lumina2Transformer2DModel,
)
from diffusers.utils.testing_utils import torch_device

from ..test_pipelines_common import PipelineTesterMixin


class Lumina2Text2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = Lumina2Text2ImgPipeline
    params = frozenset(
        [
            "prompt",
            "height",
            "width",
            "guidance_scale",
            "negative_prompt",
            "prompt_embeds",
            "negative_prompt_embeds",
        ]
    )
    batch_params = frozenset(["prompt", "negative_prompt"])
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = Lumina2Transformer2DModel(
            sample_size=4,
            patch_size=2,
            in_channels=4,
            hidden_size=8,
            num_layers=2,
            num_attention_heads=1,
            num_kv_heads=1,
            multiple_of=16,
            ffn_dim_multiplier=None,
            norm_eps=1e-5,
            scaling_factor=1.0,
            axes_dim_rope=[4, 2, 2],
            cap_feat_dim=8,
        )

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=4,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")

        torch.manual_seed(0)
        config = GemmaConfig(
            head_dim=2,
            hidden_size=8,
            intermediate_size=37,
            num_attention_heads=4,
            num_hidden_layers=2,
            num_key_value_heads=4,
        )
        text_encoder = GemmaForCausalLM(config)

        components = {
            "transformer": transformer.eval(),
            "vae": vae.eval(),
            "scheduler": scheduler,
            "text_encoder": text_encoder.eval(),
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 32,
            "width": 32,
            "output_type": "np",
        }
        return inputs

    def test_lumina_prompt_embeds(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        output_with_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = inputs.pop("prompt")

        do_classifier_free_guidance = inputs["guidance_scale"] > 1
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = pipe.encode_prompt(
            prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            device=torch_device,
        )
        output_with_embeds = pipe(
            prompt_embeds=prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            **inputs,
        ).images[0]

        max_diff = np.abs(output_with_prompt - output_with_embeds).max()
        assert max_diff < 1e-4
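
# To exercise just this prompt/embeds equivalence check, one hypothetical invocation:
#   pytest tests/pipelines/lumina2/test_pipeline_lumina2.py -k lumina_prompt_embeds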