mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Cogvideox-5B Model adapter change (#9203)
* draft of embedding --------- Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
@@ -29,6 +29,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
|
||||
|
||||
There are two models available that can be used with the CogVideoX pipeline:
|
||||
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
|
||||
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||
|
||||
## Inference
|
||||
|
||||
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
|
||||
@@ -68,7 +72,7 @@ With torch.compile(): Average inference time: 76.27 seconds.
|
||||
|
||||
### Memory optimization
|
||||
|
||||
CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
|
||||
CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
|
||||
|
||||
- `pipe.enable_model_cpu_offload()`:
|
||||
- Without enabling cpu offloading, memory usage is `33 GB`
|
||||
|
||||
@@ -86,6 +86,9 @@ TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"key_layernorm_list": reassign_query_key_layernorm_inplace,
|
||||
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
|
||||
"embed_tokens": remove_keys_inplace,
|
||||
"freqs_sin": remove_keys_inplace,
|
||||
"freqs_cos": remove_keys_inplace,
|
||||
"position_embedding": remove_keys_inplace,
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
@@ -123,11 +126,21 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def convert_transformer(ckpt_path: str):
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
num_layers: int,
|
||||
num_attention_heads: int,
|
||||
use_rotary_positional_embeddings: bool,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||
transformer = CogVideoXTransformer3DModel()
|
||||
transformer = CogVideoXTransformer3DModel(
|
||||
num_layers=num_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
||||
).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
@@ -145,9 +158,9 @@ def convert_transformer(ckpt_path: str):
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str):
|
||||
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||
vae = AutoencoderKLCogVideoX()
|
||||
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -172,13 +185,26 @@ def get_args():
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
|
||||
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
||||
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
||||
)
|
||||
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
|
||||
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
|
||||
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
|
||||
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
|
||||
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
|
||||
parser.add_argument(
|
||||
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
|
||||
)
|
||||
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
|
||||
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
|
||||
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
|
||||
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -188,18 +214,33 @@ if __name__ == "__main__":
|
||||
transformer = None
|
||||
vae = None
|
||||
|
||||
if args.fp16 and args.bf16:
|
||||
raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
|
||||
|
||||
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(args.transformer_ckpt_path)
|
||||
transformer = convert_transformer(
|
||||
args.transformer_ckpt_path,
|
||||
args.num_layers,
|
||||
args.num_attention_heads,
|
||||
args.use_rotary_positional_embeddings,
|
||||
dtype,
|
||||
)
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae = convert_vae(args.vae_ckpt_path)
|
||||
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
|
||||
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
|
||||
# Apparently, the conversion does not work any more without this :shrug:
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = CogVideoXDDIMScheduler.from_config(
|
||||
{
|
||||
"snr_shift_scale": 3.0,
|
||||
"snr_shift_scale": args.snr_shift_scale,
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
@@ -208,7 +249,7 @@ if __name__ == "__main__":
|
||||
"prediction_type": "v_prediction",
|
||||
"rescale_betas_zero_snr": True,
|
||||
"set_alpha_to_one": True,
|
||||
"timestep_spacing": "linspace",
|
||||
"timestep_spacing": "trailing",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -218,5 +259,10 @@ if __name__ == "__main__":
|
||||
|
||||
if args.fp16:
|
||||
pipe = pipe.to(dtype=torch.float16)
|
||||
if args.bf16:
|
||||
pipe = pipe.to(dtype=torch.bfloat16)
|
||||
|
||||
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
|
||||
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
|
||||
# is either fp16/bf16 here).
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
|
||||
|
||||
@@ -1783,6 +1783,148 @@ class FluxAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedCogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class XFormersAttnAddedKVProcessor:
|
||||
r"""
|
||||
Processor for implementing memory efficient attention using xFormers.
|
||||
|
||||
@@ -902,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
|
||||
@@ -374,6 +374,90 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
return embeds
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
crops_coords (`Tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
The grid size of the spatial positional embedding (height, width).
|
||||
temporal_size (`int`):
|
||||
The size of the temporal dimension.
|
||||
theta (`float`):
|
||||
Scaling factor for frequency computation.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
||||
"""
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
dim_h = embed_dim // 8 * 3
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
"""
|
||||
RoPE for image tokens with 2d structure.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -97,6 +98,7 @@ class CogVideoXBlock(nn.Module):
|
||||
eps=1e-6,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Feed Forward
|
||||
@@ -116,24 +118,24 @@ class CogVideoXBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# attention
|
||||
text_length = norm_encoder_hidden_states.size(1)
|
||||
|
||||
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
|
||||
# them in cross-attention individually
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
attn_output = self.attn1(
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
@@ -144,8 +146,9 @@ class CogVideoXBlock(nn.Module):
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@@ -231,6 +234,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
norm_eps: float = 1e-5,
|
||||
spatial_interpolation_scale: float = 1.875,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
use_rotary_positional_embeddings: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
@@ -295,12 +299,113 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
@@ -319,14 +424,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
|
||||
# 3. Position embedding
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
|
||||
pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
|
||||
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# 4. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
@@ -344,6 +451,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -351,9 +459,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
# CogVideoX-2B
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
else:
|
||||
# CogVideoX-5B
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# 5. Final block
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
|
||||
@@ -23,6 +23,7 @@ from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||
from ...models.embeddings import get_3d_rotary_pos_embed
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from ...utils import BaseOutput, logging, replace_example_docstring
|
||||
@@ -40,6 +41,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import CogVideoXPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
@@ -55,6 +57,25 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
@@ -409,6 +430,46 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def fuse_qkv_projections(self) -> None:
|
||||
r"""Enables fused QKV projections."""
|
||||
self.fusing_transformer = True
|
||||
self.transformer.fuse_qkv_projections()
|
||||
|
||||
def unfuse_qkv_projections(self) -> None:
|
||||
r"""Disable QKV projection fusion if enabled."""
|
||||
if not self.fusing_transformer:
|
||||
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.transformer.unfuse_qkv_projections()
|
||||
self.fusing_transformer = False
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
grid_crops_coords = get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
use_real=True,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -599,7 +660,14 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -620,6 +688,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
@@ -30,7 +30,12 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
to_np,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -279,6 +284,44 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
pass
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
frames = pipe(**inputs).frames # [B, F, C, H, W]
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_fused = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user